package vm

import (
	"fmt"
	"math/big"
	"testing"

	"github.com/cosmos/evm/contracts"
	testutiltypes "github.com/cosmos/evm/testutil/types"
	evmtypes "github.com/cosmos/evm/x/vm/types"

	"cosmossdk.io/math"

	sdk "github.com/cosmos/cosmos-sdk/types"
)

// TestFlashLoanExploit calls the flash loan contract's "flashLoan" and
// "flashLoanWithRevert" methods and checks the resulting chain state.
//
// Each case runs as its own subtest against freshly reset suite state and
// verifies:
//   - whether a delegation from the flash loan contract to
//     s.validatorToDelegateTo exists afterwards (and, if so, that it equals
//     s.delegatedAmountPre plus s.delegateAmount);
//   - the deployer's ERC20 balance;
//   - the flash loan contract's ERC20 balance.
//
// NOTE(review): this assumes the suite embeds testify's suite.Suite, so SetT
// is available to rebind assertions to the running subtest — confirm against
// the suite definition.
func (s *NestedEVMExtensionCallSuite) TestFlashLoanExploit() {
	testCases := []struct {
		method                  string   // contract method to invoke; also used as the subtest name
		expDelegation           bool     // whether a delegation from the contract must exist afterwards
		expSenderERC20Balance   *big.Int // expected ERC20 balance of the deployer
		expContractERC20Balance *big.Int // expected ERC20 balance of the flash loan contract
	}{
		{
			method:                  "flashLoan",
			expDelegation:           true,
			expSenderERC20Balance:   s.mintAmount,
			expContractERC20Balance: big.NewInt(0),
		},
		{
			method:                  "flashLoanWithRevert",
			expDelegation:           false,
			expSenderERC20Balance:   s.delegateAmount,
			expContractERC20Balance: s.delegateAmount,
		},
	}

	for _, tc := range testCases {
		parentT := s.T()
		parentT.Run(tc.method, func(t *testing.T) {
			// Bind the suite's assertions to this subtest. Without this,
			// s.Require() would call FailNow on the parent T from the
			// subtest's goroutine, which the testing package forbids and
			// which misattributes the failure to the parent test.
			s.SetT(t)
			defer s.SetT(parentT)

			// reset test state
			s.SetupTest()

			// Execute flash loan contract call
			_, err := s.factory.ExecuteContractCall(
				s.deployer.Priv,
				evmtypes.EvmTxArgs{To: &s.flashLoanAddr, GasPrice: big.NewInt(900_000_000), GasLimit: 400_000, Amount: s.delegateAmount},
				testutiltypes.CallArgs{ContractABI: s.flashLoanContract.ABI, MethodName: tc.method, Args: []interface{}{s.erc20Addr, s.validatorToDelegateTo}},
			)
			s.Require().NoError(err, "failed to execute flash loan")
			s.Require().NoError(s.network.NextBlock(), "failed to commit block")

			// The delegator queried below is the flash loan contract itself,
			// addressed by the SDK account form of its EVM address.
			flashLoanAccAddr := sdk.AccAddress(s.flashLoanAddr.Bytes())

			if tc.expDelegation {
				// Check delegation exists
				delRes, err := s.handler.GetDelegation(flashLoanAccAddr.String(), s.validatorToDelegateTo)
				s.Require().NoError(err, "failed to get delegation")
				delAmtPost := delRes.DelegationResponse.Balance.Amount
				expected := s.delegatedAmountPre.Add(math.NewIntFromBigInt(s.delegateAmount))
				s.Require().Equal(expected, delAmtPost, "delegated amount mismatch")
			} else {
				// Expect no delegation
				_, err := s.handler.GetDelegation(flashLoanAccAddr.String(), s.validatorToDelegateTo)
				s.Require().Error(err, "delegation should not exist")
				s.Require().Contains(err.Error(), fmt.Sprintf("delegation with delegator %s not found for validator %s", flashLoanAccAddr.String(), s.validatorToDelegateTo))
			}

			// Verify deployer ERC20 balance. balanceOf is sent through
			// ExecuteContractCall, so a block is committed before decoding.
			res, err := s.factory.ExecuteContractCall(
				s.deployer.Priv,
				evmtypes.EvmTxArgs{To: &s.erc20Addr},
				testutiltypes.CallArgs{ContractABI: contracts.ERC20MinterBurnerDecimalsContract.ABI, MethodName: "balanceOf", Args: []interface{}{s.deployer.Addr}},
			)
			s.Require().NoError(err, "failed to get deployer balance")
			s.Require().NoError(s.network.NextBlock(), "failed to commit block")
			ethRes, err := evmtypes.DecodeTxResponse(res.Data)
			s.Require().NoError(err, "failed to decode balance response")
			unpacked, err := contracts.ERC20MinterBurnerDecimalsContract.ABI.Unpack("balanceOf", ethRes.Ret)
			s.Require().NoError(err, "failed to unpack balance")
			bal, ok := unpacked[0].(*big.Int)
			s.Require().True(ok, "balance is not *big.Int")
			// Compare decimal strings so the check depends only on the
			// numeric value, not on the internal *big.Int representation.
			s.Require().Equal(tc.expSenderERC20Balance.String(), bal.String(), "deployer balance mismatch")

			// Verify flash loan contract ERC20 balance
			res2, err := s.factory.ExecuteContractCall(
				s.deployer.Priv,
				evmtypes.EvmTxArgs{To: &s.erc20Addr},
				testutiltypes.CallArgs{ContractABI: contracts.ERC20MinterBurnerDecimalsContract.ABI, MethodName: "balanceOf", Args: []interface{}{s.flashLoanAddr}},
			)
			s.Require().NoError(err, "failed to get contract balance")
			s.Require().NoError(s.network.NextBlock(), "failed to commit block")
			ethRes2, err := evmtypes.DecodeTxResponse(res2.Data)
			s.Require().NoError(err, "failed to decode contract balance response")
			unpacked2, err := contracts.ERC20MinterBurnerDecimalsContract.ABI.Unpack("balanceOf", ethRes2.Ret)
			s.Require().NoError(err, "failed to unpack contract balance")
			bal2, ok := unpacked2[0].(*big.Int)
			s.Require().True(ok, "contract balance is not *big.Int")
			s.Require().Equal(tc.expContractERC20Balance.String(), bal2.String(), "contract balance mismatch")
		})
	}
}
